import numpy as np
import matplotlib.pyplot as plt
from generator.data_stream import MCMC
from optimization.mnist import MNIST, MNIST2
import pickle
import argparse


def main(args):
    """Run repeated MNIST training experiments and pickle their histories.

    For each of ``args.num_rep`` repetitions a fresh model wrapper is built and
    trained for ``args.steps`` steps, calling ``eval(eval_loss=True)`` before
    every ``step()``. After each repetition its ``acc_hist`` and ``loss_hist``
    are collected. At the end, the collected lists are written to
    ``accuracy_<run_id>.pickle`` and ``loss_<run_id>.pickle`` in the current
    working directory.

    Args:
        args: parsed command-line namespace providing ``lr``, ``bs``, ``r``,
            ``num_rep`` and ``steps``.
    """
    # Run identifier used both for logging and for the output filenames.
    run_id = f"TUNE_PURELS_lr{args.lr}_bs{args.bs}_r{args.r}"
    print(run_id)

    all_acc = []
    all_loss = []
    for rep in range(args.num_rep):
        print(rep)

        if args.r < 0:
            # Negative r selects MNIST2, which is built from batch size and lr
            # directly, without an MCMC data generator.
            model = MNIST2(args.bs, args.lr)
        else:
            generator = MCMC(r=args.r, batch_size=args.bs, dim=1)
            # NOTE(review): lr is multiplied by the batch size on this path but
            # not in the MNIST2 branch — confirm this scaling is intended.
            model = MNIST(generator, args.lr * args.bs)

        for _ in range(args.steps):
            model.eval(eval_loss=True)
            model.step()
        all_acc.append(model.acc_hist)
        all_loss.append(model.loss_hist)

    # One pickle file per metric; every file holds a list with one history per repetition.
    for prefix, history in (('accuracy_', all_acc), ('loss_', all_loss)):
        with open(prefix + run_id + '.pickle', 'wb') as handle:
            pickle.dump(history, handle, protocol=pickle.HIGHEST_PROTOCOL)

    print("Training completed.")

def build_parser():
    """Build the command-line parser for the dependent-data experiments.

    Returns:
        argparse.ArgumentParser: parser exposing ``--steps``, ``--lr``,
        ``--bs``, ``--num_rep`` and ``--r``.
    """
    parser = argparse.ArgumentParser(description='Dependent Data')
    parser.add_argument('--steps', type=int, default=1000, help='Number of steps')
    # Same value as before (previously written ``10e-4``, which reads like 1e-4
    # but equals 1e-3).
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='Learning rate')
    parser.add_argument('--bs', type=int, default=100, help='Batch size')
    parser.add_argument('--num_rep', type=int, default=10, help='Number of repeating experiments')
    # In main(), a negative --r selects MNIST2 instead of the MCMC-driven MNIST.
    parser.add_argument('--r', type=float, default=1.0, help='Strength of data dependence')
    return parser


if __name__ == "__main__":
    parse_args = build_parser().parse_args()

    main(args=parse_args)